Описание¶

Задачи¶

Используются три задачи:

  • DM -- двухальтернативный выбор
  • Romo -- сравнение двух сигналов, разделённых задержкой
  • CtxDM -- DM с контекстом. Вход состоит из одного контекстного входа и стимулов (вместе — первые три входа) и 6 входов, кодирующих задачи (всего 9 входов). Выход, как и раньше, состоит из трёх частей: контекстного выхода и двух выходов принятия решения.

Сеть¶

Сеть состоит из LIF AdEx-нейронов.

Импорт всех необходимых библиотек¶

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from cgtasknet.net.lifadex import SNNlifadex
from cgtasknet.tasks.reduce import (
    CtxDMTaskParameters,
    DMTaskParameters,
    DMTaskRandomModParameters,
    MultyReduceTasks,
    RomoTaskParameters,
    RomoTaskRandomModParameters,
)
from norse.torch.functional.lif_adex import LIFAdExParameters
from tqdm import tqdm

Определяем устройство¶

In [2]:
# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"{device=}")
device=device(type='cuda', index=0)

Функция построения входов и выходов¶

In [3]:
import os


def _to_numpy(array):
    """Return *array* as a NumPy array (torch tensors are detached and moved to CPU)."""
    if isinstance(array, torch.Tensor):
        return array.detach().cpu().numpy()
    return np.asarray(array)


def plot_results(
    inputs,
    target_outputs,
    outputs,
    *,
    task_names=None,
    run_name=None,
    max_plots=20,
    figures_dir="figures",
):
    """Plot inputs, task code, target and network outputs for several batch items.

    One figure with four panels is drawn per batch item. Each figure is saved as
    ``<figures_dir>/network_outputs_<run_name>_batch_<k>.pdf``, then shown and closed.

    Args:
        inputs: Network inputs indexed as ``[time, batch, channel]``. The first
            three channels are drawn as signals; the remaining channels are
            drawn as the task code (their value at ``t = 0``). A torch tensor
            or a NumPy array.
        target_outputs: Target outputs indexed as ``[time, batch, channel]``
            (tensor or array).
        outputs: Network outputs indexed as ``[time, batch, channel]``
            (tensor or array).
        task_names: Tick labels for the task-code panel (sorted before use).
            Defaults to the notebook-level ``tasks`` list.
        run_name: Name embedded in the saved file names. Defaults to the
            notebook-level ``name``.
        max_plots: Maximum number of batch items to plot.
        figures_dir: Directory the PDFs are written to; created if missing.
    """
    if task_names is None:
        task_names = tasks  # notebook-level task list, as in the original
    if run_name is None:
        run_name = name  # notebook-level checkpoint name, as in the original

    # Convert each array once, up front. Previously `t_outputs` was only bound
    # when both inputs and targets were tensors (NameError otherwise), and
    # `outputs` was copied off the device again for every plotted line.
    inputs = _to_numpy(inputs)
    t_outputs = _to_numpy(target_outputs)
    outputs = _to_numpy(outputs)

    os.makedirs(figures_dir, exist_ok=True)
    # Number of items comes from the batch axis of the data itself rather than
    # a global batch size.
    for batch in range(min(inputs.shape[1], max_plots)):
        fig, (ax_in, ax_code, ax_target, ax_out) = plt.subplots(
            1, 4, figsize=(15, 3)
        )

        ax_in.set_title("Inputs")
        ax_in.set_xlabel("$time, ms$")
        ax_in.set_ylabel("$Magnitude$")
        for i in range(3):
            ax_in.plot(inputs[:, batch, i], label=rf"$in_{i + 1}$")
        ax_in.legend()

        # Task code: one vertical bar per task-code channel, height taken at t=0.
        ax_code.set_title("Task code (context)")
        ax_code.set_xticks(np.arange(1, len(task_names) + 1))
        ax_code.set_xticklabels(sorted(task_names), rotation=90)
        ax_code.set_yticks([])
        for i in range(3, inputs.shape[-1]):
            ax_code.plot([i - 2] * 2, [0, inputs[0, batch, i]])

        ax_target.set_title("Target output")
        ax_target.set_xlabel("$time, ms$")
        for i in range(t_outputs.shape[-1]):
            ax_target.plot(t_outputs[:, batch, i], label=rf"$out_{i + 1}$")
        ax_target.legend()

        ax_out.set_title("Real output")
        ax_out.set_xlabel("$time, ms$")
        for i in range(outputs.shape[-1]):
            ax_out.plot(outputs[:, batch, i], label=rf"$out_{i + 1}$")
        ax_out.legend()

        fig.tight_layout()
        fig.savefig(
            os.path.join(figures_dir, f"network_outputs_{run_name}_batch_{batch}.pdf")
        )
        plt.show()
        plt.close(fig)

Определяем датасет¶

Датасет будет состоять из трех типов задач:

  • DM задача;
  • Romo задача;
  • CtxDM задача. Параметры для последней задачи совпадают с параметрами DM задачи.

Параметры датасета:¶

In [4]:
# Run-level sizes.
batch_size, number_of_epochs, number_of_tasks = 100, 2000, 1

# Romo timing: base trial/delay durations plus the positive random-shift
# bounds understood by cgtasknet (printed defaults show dt=0.001; units
# presumably seconds — TODO confirm).
_romo_timing = RomoTaskParameters(
    trial_time=0.1,
    positive_shift_trial_time=0.2,
    delay=0.1,
    positive_shift_delay_time=1.4,
)
romo_parameters = RomoTaskRandomModParameters(romo=_romo_timing)

# DM timing.
_dm_timing = DMTaskParameters(positive_shift_trial_time=0.8, trial_time=0.1)
dm_parameters = DMTaskRandomModParameters(dm=_dm_timing)

# CtxDM shares the very same DM timing object.
ctx_parameters = CtxDMTaskParameters(dm=dm_parameters.dm)

Датасет¶

In [5]:
# Std of the Gaussian noise added to the stimulus inputs in the training loop.
sigma = 0.1
tasks = ["RomoTask1", "RomoTask2", "DMTask1", "DMTask2", "CtxDMTask1", "CtxDMTask2"]

# Each task family appears twice; both variants share one parameter object.
_params_in_task_order = [
    romo_parameters,
    romo_parameters,
    dm_parameters,
    dm_parameters,
    ctx_parameters,
    ctx_parameters,
]
task_dict = dict(zip(tasks, _params_in_task_order))

Task = MultyReduceTasks(
    tasks=task_dict,
    batch_size=batch_size,
    delay_between=0,
    enable_fixation_delay=True,
)

print("Task parameters:")
for task_name, task_params in task_dict.items():
    print(f"{task_name}:\n{task_params}\n")

n_inputs = Task.feature_and_act_size[0]
n_outputs = Task.feature_and_act_size[1]
print(f"inputs/outputs: {n_inputs}/{n_outputs}")
Task parameters:
RomoTask1:
RomoTaskRandomModParameters(romo=RomoTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=(None, None), delay=0.1, negative_shift_trial_time=0, positive_shift_trial_time=0.2, negative_shift_delay_time=0, positive_shift_delay_time=1.4), n_mods=2)

RomoTask2:
RomoTaskRandomModParameters(romo=RomoTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=(None, None), delay=0.1, negative_shift_trial_time=0, positive_shift_trial_time=0.2, negative_shift_delay_time=0, positive_shift_delay_time=1.4), n_mods=2)

DMTask1:
DMTaskRandomModParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), n_mods=2)

DMTask2:
DMTaskRandomModParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), n_mods=2)

CtxDMTask1:
CtxDMTaskParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), context=None, value=(None, None))

CtxDMTask2:
CtxDMTaskParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), context=None, value=(None, None))

inputs/outputs: 9/3

Пример генерации датасета¶

In [6]:
# Draw one example batch and plot the first few items: stimulus inputs,
# task-code inputs over time, and target outputs. Arrays are indexed
# [time, batch, channel] (see the indexing below).
inputs, t_outputs = Task.dataset(n_trials=1)
for batch in range(min(inputs.shape[1], 10)):
    fig, (ax_in, ax_code, ax_target) = plt.subplots(1, 3, figsize=(15, 3))

    ax_in.set_title("Inputs")
    ax_in.set_xlabel("$time, ms$")
    ax_in.set_ylabel("$Magnitude$")
    for i in range(3):
        ax_in.plot(inputs[:, batch, i], label=rf"$in_{i + 1}$")
    ax_in.legend()

    ax_code.set_title("Task code (context)")
    ax_code.set_xlabel("$time, ms$")
    for i in range(3, inputs.shape[-1]):
        ax_code.plot(inputs[:, batch, i], label=rf"$in_{i + 1}$")
    ax_code.legend()

    ax_target.set_title("Target output")
    ax_target.set_xlabel("$time, ms$")
    for i in range(t_outputs.shape[-1]):
        ax_target.plot(t_outputs[:, batch, i], label=rf"$out_{i + 1}$")
    ax_target.legend()

    fig.tight_layout()
    # Show and close each figure inside the loop; closing only once after the
    # loop released just the last figure and leaked the others.
    plt.show()
    plt.close(fig)
del inputs
del t_outputs

Инициализация сети и перенос на device¶

In [7]:
feature_size, output_size = Task.feature_and_act_size
hidden_size = 450

# Per-neuron adaptation rate tau_ada_inv is drawn uniformly from [low, high).
_tau_ada_inv_low, _tau_ada_inv_high = 0.5, 6
neuron_parameters = LIFAdExParameters(
    v_th=torch.as_tensor(0.65),
    tau_ada_inv=_tau_ada_inv_low
    + (_tau_ada_inv_high - _tau_ada_inv_low) * torch.rand(hidden_size).to(device),
    alpha=100,
    method="super",  # presumably norse's SuperSpike surrogate gradient — TODO confirm
    # NOTE: a reset parameter (rho_reset = torch.as_tensor(5)) was tried and disabled.
)

model = SNNlifadex(
    feature_size,
    hidden_size,
    output_size,
    neuron_parameters=neuron_parameters,
    tau_filter_inv=500,
)
model = model.to(device)

Критерий и функция ошибки¶

In [8]:
learning_rate = 1e-2


class RMSELoss(nn.Module):
    """Root-mean-square error loss, ``sqrt(MSE(yhat, y))``.

    The MSE is clamped from below at ``eps`` before the square root. The
    derivative of ``sqrt`` is infinite at 0, so without the clamp a perfect
    prediction (MSE == 0) back-propagates ``inf * 0 = nan`` gradients.

    Args:
        eps: Lower bound applied to the MSE before taking the root. For an MSE
            at or above ``eps`` the result is exactly ``sqrt(MSE)``.
    """

    def __init__(self, eps=1e-12):
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, yhat, y):
        return torch.sqrt(torch.clamp(self.mse(yhat, y), min=self.eps))


# Training objective is plain MSE (the RMSELoss class above is not used here),
# optimized with Adam over all model parameters.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

Генерация всех эпох¶

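Если памяти достаточно, все эпохи можно сгенерировать заранее; если память не позволяет, каждую эпоху необходимо генерировать в основном цикле обучения (по умолчанию предварительная генерация отключена).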

In [9]:
# Set to True to generate every epoch up front (needs enough memory);
# otherwise each epoch is generated inside the training loop.
PREGENERATE_EPOCHS = False
if PREGENERATE_EPOCHS:
    list_inputs = []
    list_t_outputs = []
    for _ in tqdm(range(number_of_epochs)):
        temp_input, temp_t_output = Task.dataset()
        # astype returns a new array; the previous code discarded the result,
        # so the float16 down-cast (meant to save memory) never took effect.
        temp_input = temp_input.astype(np.float16)
        temp_t_output = temp_t_output.astype(np.float16)
        # Noise on the first three (stimulus) channels only, as in the
        # training loop; the task-code channels stay clean.
        temp_input[:, :, :3] += np.random.normal(
            0, sigma, size=temp_input[:, :, :3].shape
        )
        list_inputs.append(temp_input)
        list_t_outputs.append(temp_t_output)

Основной цикл обучения¶

In [29]:
from cgtasknet.instruments.instrument_accuracy_network import correct_answer
from cgtasknet.net.states import LIFAdExRefracInitState

# Checkpoint file name (encodes the adaptation strength alpha and layer size).
name = f"Train_dm_and_romo_task_reduce_lif_adex_without_refrac_random_delay_long_a_alpha_{neuron_parameters.alpha}_N_{hidden_size}"
# NOTE(review): init_state is never passed to the model below (it is called
# with its default state); kept in case later cells rely on it.
init_state = LIFAdExRefracInitState(batch_size, hidden_size, device=device)

log_every = 10  # epochs between loss logging, checkpointing and evaluation
n_eval_batches = 10  # fresh batches per periodic accuracy estimate
eval_sigma = 0.01  # stimulus-noise std used for the periodic accuracy estimate

running_loss = 0.0
for i in tqdm(range(number_of_epochs)):
    inputs, target_outputs = Task.dataset()
    # Noise only on the first three (stimulus) channels. The noise must have
    # the slice's shape: the full `inputs.shape` cannot broadcast into it.
    inputs[:, :, :3] += np.random.normal(0, sigma, size=inputs[:, :, :3].shape)
    inputs = torch.from_numpy(inputs).type(torch.float).to(device)
    target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)

    # forward + backward + optimize
    optimizer.zero_grad()
    outputs, _state = model(inputs)
    loss = criterion(outputs, target_outputs)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    # Release this step's tensors before any evaluation allocates new ones.
    del inputs, target_outputs, outputs, _state, loss
    torch.cuda.empty_cache()

    if i % log_every == log_every - 1:
        with open("log_multy.txt", "a") as f:
            f.write(
                "epoch: {:d} loss: {:0.5f}\n".format(i + 1, running_loss / log_every)
            )
        running_loss = 0.0
        torch.save(model.state_dict(), name)

        # Accuracy on fresh batches; no_grad avoids building autograd graphs.
        result = 0
        with torch.no_grad():
            for _ in range(n_eval_batches):
                inputs, target_outputs = Task.dataset(1, delay_between=0)
                # Same convention as the test cells: task code is not noised.
                inputs[:, :, :3] += np.random.normal(
                    0, eval_sigma, size=inputs[:, :, :3].shape
                )
                inputs = torch.from_numpy(inputs).type(torch.float).to(device)
                target_outputs = (
                    torch.from_numpy(target_outputs).type(torch.float).to(device)
                )
                outputs = model(inputs)[0]
                answers = correct_answer(
                    outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
                )
                result += torch.sum(answers).item()
                del inputs, target_outputs, outputs, answers
                torch.cuda.empty_cache()

        accuracy = result / batch_size / n_eval_batches * 100
        with open("accuracy_multy.txt", "a") as f:
            f.write(f"epoch = {i}; correct/all = {accuracy}\n")
print("Finished Training")
  0%|          | 1/2000 [00:05<3:08:26,  5.66s/it]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Input In [29], in <module>
     12 optimizer.zero_grad()
     14 # forward + backward + optimize
---> 15 outputs, _ = model(inputs)
     17 loss = criterion(outputs, target_outputs)
     18 loss.backward()

File a:\src\multy_task\env\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File a:\src\multy_task\env\lib\site-packages\cgtasknet-0.0.1-py3.8.egg\cgtasknet\net\lifadex.py:40, in SNNlifadex.forward(self, x, state)
     37 def forward(
     38     self, x: torch.tensor, state: Optional[LIFAdExState] = None
     39 ) -> Tuple[torch.tensor, LIFAdExState]:
---> 40     outputs, states = save_states(x, self.save_states, self.alif, state)
     42     outputs = self.exp_f(outputs)
     43     return (outputs, states)

File a:\src\multy_task\env\lib\site-packages\cgtasknet-0.0.1-py3.8.egg\cgtasknet\net\save_states.py:16, in save_states(x, save_states, layer, state)
     14     outputs = torch.concat(outputs, axis=0)
     15 else:
---> 16     outputs, states = layer(x, state)
     17 return outputs, states

File a:\src\multy_task\env\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File a:\src\multy_task\env\lib\site-packages\norse-0.0.7.post1-py3.8-win-amd64.egg\norse\torch\module\snn.py:317, in SNNRecurrent.forward(self, input_tensor, state)
    310 activation = (
    311     self.activation_sparse
    312     if self.activation_sparse is not None and input_tensor.is_sparse
    313     else self.activation
    314 )
    316 for ts in range(T):
--> 317     out, state = activation(
    318         input_tensor[ts],
    319         state,
    320         self.input_weights,
    321         self.recurrent_weights,
    322         self.p,
    323         self.dt,
    324     )
    325     outputs.append(out)
    327 return torch.stack(outputs), state

File a:\src\multy_task\env\lib\site-packages\norse-0.0.7.post1-py3.8-win-amd64.egg\norse\torch\functional\lif_adex.py:123, in lif_adex_step(input_tensor, state, input_weights, recurrent_weights, p, dt)
    121 dv_leak = p.v_leak - state.v
    122 dv_exp = p.delta_T * torch.exp((state.v - p.v_th) / p.delta_T)
--> 123 dv = dt * p.tau_mem_inv * (dv_leak + dv_exp + state.i - state.a)
    124 v_decayed = state.v + dv
    126 # compute current updates

KeyboardInterrupt: 

Тестирование¶

Шум: np.random.normal(0, 0.01, size=inputs[:, :, :3].shape) — добавляется только к первым трём входам (код задачи не зашумляется).

In [127]:
noise_std = 0.01  # std of Gaussian noise added to the three stimulus channels
n_test_batches = 100  # number of fresh batches the accuracy is averaged over

result = 0
with torch.no_grad():  # inference only: do not build autograd graphs
    for _ in tqdm(range(n_test_batches)):
        # Drop the previous batch before allocating the next one.
        inputs = target_outputs = outputs = None
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(
            0, noise_std, size=inputs[:, :, :3].shape
        )
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()

accuracy = result / batch_size / n_test_batches * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
# Release the last batch (rebinding replaces the old bare try/except `del`s).
inputs = target_outputs = outputs = None
torch.cuda.empty_cache()
100%|██████████| 100/100 [03:37<00:00,  2.17s/it]
93.09

Тестирование¶

Шум: np.random.normal(0, 0.05, size=inputs[:, :, :3].shape) — добавляется только к первым трём входам (код задачи не зашумляется).

In [187]:
noise_std = 0.05  # std of Gaussian noise added to the three stimulus channels
n_test_batches = 100  # number of fresh batches the accuracy is averaged over

result = 0
with torch.no_grad():  # inference only: do not build autograd graphs
    for _ in tqdm(range(n_test_batches)):
        # Drop the previous batch before allocating the next one.
        inputs = target_outputs = outputs = None
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(
            0, noise_std, size=inputs[:, :, :3].shape
        )
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()

accuracy = result / batch_size / n_test_batches * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
# Release the last batch (rebinding replaces the old bare try/except `del`s).
inputs = target_outputs = outputs = None
torch.cuda.empty_cache()
100%|██████████| 100/100 [03:40<00:00,  2.20s/it]
93.36

Тестирование¶

Шум: np.random.normal(0, 0.1, size=inputs[:, :, :3].shape) — добавляется только к первым трём входам (код задачи не зашумляется).

In [188]:
noise_std = 0.1  # std of Gaussian noise added to the three stimulus channels
n_test_batches = 100  # number of fresh batches the accuracy is averaged over

result = 0
with torch.no_grad():  # inference only: do not build autograd graphs
    for _ in tqdm(range(n_test_batches)):
        # Drop the previous batch before allocating the next one.
        inputs = target_outputs = outputs = None
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(
            0, noise_std, size=inputs[:, :, :3].shape
        )
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()

accuracy = result / batch_size / n_test_batches * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
# Release the last batch (rebinding replaces the old bare try/except `del`s).
inputs = target_outputs = outputs = None
torch.cuda.empty_cache()
100%|██████████| 100/100 [03:39<00:00,  2.20s/it]
92.45

Тестирование¶

Шум: np.random.normal(0, 0.5, size=inputs[:, :, :3].shape) — добавляется только к первым трём входам (код задачи не зашумляется).

In [189]:
noise_std = 0.5  # std of Gaussian noise added to the three stimulus channels
n_test_batches = 100  # number of fresh batches the accuracy is averaged over

result = 0
with torch.no_grad():  # inference only: do not build autograd graphs
    for _ in tqdm(range(n_test_batches)):
        # Drop the previous batch before allocating the next one.
        inputs = target_outputs = outputs = None
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(
            0, noise_std, size=inputs[:, :, :3].shape
        )
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()

accuracy = result / batch_size / n_test_batches * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
# Release the last batch (rebinding replaces the old bare try/except `del`s).
inputs = target_outputs = outputs = None
torch.cuda.empty_cache()
100%|██████████| 100/100 [03:34<00:00,  2.14s/it]
79.13
In [190]:
# Rebind both names to a placeholder, presumably to drop references to the
# tensors left over from the last evaluation cell.
inputs = outputs = 0
In [191]:
# Save the per-neuron adaptation rates to "tau_ada_inv_alpha=<alpha>.npy"
# (np.save appends the .npy suffix).
_alpha = neuron_parameters.alpha
tau_ada_inv_distrib = neuron_parameters.tau_ada_inv.cpu().numpy()
np.save(f"tau_ada_inv_alpha={_alpha}", tau_ada_inv_distrib)
In [192]:
lines = []
with open("accuracy_multy.txt", "r") as f:
    # Records are written as "<...> = <epoch>; correct/all = <accuracy>", so
    # the accuracy is the third "="-separated field.
    lines.extend(float(record.split("=")[2].strip()) for record in f)
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
Input In [192], in <module>
      1 lines = []
----> 2 with open("accuracy_multy.txt", "r") as f:
      3     while line := f.readline():
      4         lines.append(float(line.split("=")[2].strip()))

FileNotFoundError: [Errno 2] No such file or directory: 'accuracy_multy.txt'
In [ ]:
# Accuracy is logged every 10 epochs at 0-based epochs 9, 19, 29, ...
# Derive the x values from len(lines) so interrupted or partial runs still
# plot, instead of assuming exactly 2000 epochs.
# NOTE(review): the log file is opened in append mode, so records from
# earlier runs are included too.
logged_epochs = list(range(9, 9 + 10 * len(lines), 10))
plt.figure(figsize=(8, 5))
plt.plot(logged_epochs, lines, ".", linestyle="--", markersize=5)
plt.ylabel(r"Accuracy%")
plt.xlabel(r"Epochs")
In [ ]: